library(tictoc)
library(MDEI)
library(grf)
library(KRLS2)

## Functions used in the analysis ----
# Use the macOS Quartz device for interactive plots only where it exists;
# setting it unconditionally breaks interactive plotting on Linux/Windows.
# (The figures below are written with pdf() and do not depend on this.)
if (isTRUE(capabilities("aqua"))) {
  options(device = "quartz")
}
# Optional extra interval set folded into plot.tsq()'s default y-axis range.
# NULL contributes nothing to range(c(...)), so by default only the MDEI
# intervals determine the axis.
b1.CIs <- NULL
# Plot point estimates and their intervals against the true effect.
#
# Each estimate is drawn as a point and each interval as a vertical segment
# at its horizontal position, over the reference curve (x, truth). Intervals
# that strictly contain the truth are drawn in light gray; those that miss it
# are drawn in black. The empirical coverage rate is printed.
#
# Args:
#   pointest: numeric vector of estimated effects, one per observation.
#   CI: two-column matrix (or data frame) of interval bounds, one row per
#     observation; the bounds may be in either column order.
#   x: horizontal positions. Defaults to the global `treat`, which in this
#     simulation equals the true effect, so the reference curve is the
#     45-degree line.
#   truth: true effects used for the coverage check. Defaults to the global
#     `tau.true`.
#   ylim: vertical axis range. Defaults to the range of the MDEI intervals
#     (plus `b1.CIs`, if set) so that all panels share one scale.
#
# Returns:
#   The coverage rate (share of intervals strictly containing `truth`),
#   invisibly.
plot.tsq <- function(pointest,
                     CI,
                     x = treat,
                     truth = tau.true,
                     ylim = range(c(m1$CIs.tau, b1.CIs))) {
  # (lo - truth) * (hi - truth) < 0 exactly when truth lies strictly between
  # the two bounds, whichever column holds the lower one.
  covered <- (CI[, 1] - truth) * (CI[, 2] - truth) < 0

  plot(x, pointest, type = "n", ylim = ylim, xlab = "", ylab = "")
  lines(x, truth)
  segments(
    x0 = x,
    x1 = x,
    y0 = CI[, 1],
    y1 = CI[, 2],
    col = ifelse(covered, gray(0.7), "black")
  )
  # The horizontal axis carries the true effect and the vertical axis the
  # estimate, so label the sides accordingly.
  mtext("True Effect", side = 1, line = 2.2, cex = 1.25)
  mtext("Estimated Effect", side = 2, line = 2, cex = 1.25)

  points(x, pointest, pch = 19, cex = 0.5)

  coverage <- mean(covered)
  print(coverage)
  invisible(coverage)
}


## Generate data and set up the simulation ----
# Simulated design: n observations, a five-column matrix of standard-normal
# covariates X, a sorted standard-normal treatment, and an outcome whose
# conditional mean is a centered quadratic in the treatment.
n <- 1000
set.seed(1)
X <- matrix(rnorm(n * 5), nrow = n)
# NOTE(review): re-seeding with the same seed means `treat` is built from the
# same draws that fill X[, 1] (sorted, so the row alignment differs). The
# outcome noise drawn further below continues that stream, and with R's
# default RNG it reproduces X[, 2] exactly. Confirm this coupling is intended.
# Changing the seeding will change every downstream result and figure.
set.seed(1)
# Presumably sorted so that lines() calls against `treat` draw monotone curves.
treat <- sort(rnorm(n))

# Conditional mean theta(T) = T^2 / 2, centered to have sample mean zero.
theta.true <- treat ^ 2 / 2
theta.true <- theta.true - mean(theta.true)
# Effect function: the derivative of theta with respect to T, i.e. T itself.
tau.true <- treat

# Outcome = conditional mean + N(0, 1) noise (see the seeding note above).
y <- theta.true  + rnorm(n, sd = 1)

# Fit MDEI on (y, treat, X) and report the elapsed wall-clock time.
# NOTE(review): alpha = 0.9 presumably matches the 90% intervals built for
# GRF and KRLS below — confirm against the MDEI documentation.
tic()
set.seed(1)
m1 <- MDEI(
  y, treat, X,
  splits = 100,
  alpha = 0.9
)
toc()


# Causal forest baseline (treat passed as the treatment W), with
# per-observation variance estimates for interval construction.
set.seed(1)
grf1 <- causal_forest(X, y, treat, num.trees = 10000)
predict.grf <- predict(grf1, estimate.variance = TRUE)
# Normal-approximation intervals: prediction -/+ 1.645 standard errors,
# where 1.645 is the two-sided 90% standard-normal quantile.
grf.CIs <- with(
  predict.grf,
  predictions + 1.645 * cbind(-variance.estimates^0.5, variance.estimates^0.5)
)

# Kernel regularized least squares baseline on the design (treat, X). The
# effect estimate is the pointwise derivative with respect to `treat`.
set.seed(1)
krls1 <- krls(y = y, X = cbind(treat, X), epsilon = 0.001)
sk1 <- inference.krls2(krls1)
# `treat` is the first design column, so column 1 holds its derivatives.
k1.est <- sk1$derivatives[, 1]
k1.se <- sk1$var.derivatives[, 1]^0.5
k1.ci <- local({
  z <- 1.645 # two-sided 90% standard-normal quantile
  k1.est + cbind(-z * k1.se, z * k1.se)
})

## Figure 2 ----
# Figure 2: the simulated data together with the true conditional mean and
# the true effect function.
pdf("Figure2_tsqex1.pdf", height = 5, width = 6)
par(mar = c(3, 3.3, 1.3, 0.4))
# NOTE(review): the y-range is taken from the MDEI intervals (the same scale
# as Figure 3), so observations or parts of the theta curve outside that
# range may be clipped — confirm this is intended.
plot(
  treat,
  m1$tau.est,
  type = "n",
  ylim = range(m1$CIs.tau),
  xlab = "",
  ylab = ""
)
lines(treat, tau.true)            # effect function tau, equal to T here
lines(treat, theta.true, lty = 2) # centered conditional mean theta
points(treat, y, pch = 19, cex = 0.5)
mtext("Data Setup", side = 3, cex = 1.25)

legend.txt <-
  c(expression(paste("Conditional Mean, ", theta(T, X))),
    expression(paste("Effect Function, ", tau(T, X))))
# lty order matches the curves above: theta dashed (2), tau solid (1).
legend("bottomright",
       legend = legend.txt,
       lty = 2:1,
       bty = "n")

dev.off()

## Figure 3 ----
# Figure 3: estimated vs. true effects with intervals for each method, side by
# side. Each plot.tsq() call also prints that method's coverage rate.
pdf("Figure3_tsqex2.pdf", height = 4, width = 12)
par(mfrow = c(1, 3), mar = c(3, 3.3, 1.3, 0.4))

plot.tsq(m1$tau.est, m1$CIs.tau)
mtext("MDEI", side = 3, cex = 1.25)

# Take the points from the same prediction object the GRF intervals were
# built from, so the points and intervals are guaranteed to correspond.
plot.tsq(predict.grf$predictions, grf.CIs)
mtext("GRF", side = 3, cex = 1.25)

plot.tsq(k1.est, k1.ci)
mtext("KRLS", side = 3, cex = 1.25)

dev.off()
